Simple DDPM¶

In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import math
import matplotlib.pyplot as plt

Cosine Noise Scheduler¶

In [77]:
# total timesteps
T = 1000
In [78]:
# # cosine scheduling (Improved DDPM)
# s = 0.008
# f = torch.cos(torch.pi * (torch.linspace(0, 1, T+1) + s) / (2 + 2*s)) ** 2
# alpha_bar = f / f[0]

# alpha = alpha_bar[1:] / alpha_bar[:-1]
# alpha_bar = alpha_bar[1:]
# beta = 1 - alpha
In [79]:
# linear scheduling
beta_1 = 1e-4   # alpha_1 = 0.9999
beta_T = 0.02   # alpha_T = 0.98

beta = torch.linspace(beta_1, beta_T, T)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
In [80]:
print(beta.shape)
print(alpha.shape)
print(alpha_bar.shape)
torch.Size([1000])
torch.Size([1000])
torch.Size([1000])

Dataset¶

$$ \sin(2\pi at) + \sin(2\pi bt) \;\; where \;\; a \sim \text{Uniform}(100,110), b \sim \text{Uniform}(1,11) $$

In [81]:
n = 4000    # number of data

a = torch.randint(1000, 1101, (n,), dtype=torch.float32) / 10
b = torch.randint(10, 111, (n,), dtype=torch.float32) / 10

fs = 4000
t = torch.linspace(0, 1, 1*fs)

data = torch.sin(2 * torch.pi * a.reshape(-1,1) @ t.reshape(1,-1)) + torch.sin(2 * torch.pi * b.reshape(-1,1) @ t.reshape(1,-1))
data.shape
Out[81]:
torch.Size([4000, 4000])
In [82]:
# def pattern(d,h,r,k):
#     return 4*(torch.sin(k*d*h/r))**2

# n = 1000    # number of data

# d = torch.randint(200, 301, (n,), dtype=torch.float32) / 10
# h = torch.randint(100, 301, (n,), dtype=torch.float32) / 10
# r = torch.randint(160, 201, (n,), dtype=torch.float32)
# k=torch.linspace(0, 4.2, 5120)

# t = torch.linspace(0, 2560, 5120)
# data = pattern(d.reshape(-1,1),h.reshape(-1,1),r.reshape(-1,1),k)
# data.shape
In [83]:
# data examples
for i in range(10):
    plt.figure(figsize=(12,5))
    plt.plot(t, data[i*10])
    plt.grid()
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [84]:
data = data.unsqueeze(dim=1)
data.shape
Out[84]:
torch.Size([4000, 1, 4000])
In [85]:
class MySignalDataset(Dataset):
    def __init__(self, data):
        """
        data: (num_signals, channel=1, signal_length) 형태의 텐서
        """
        self.data = data

    def __len__(self):
        return self.data.size(0)  # 신호의 총 개수 (1000)

    def __getitem__(self, idx):
        """
        idx번째 신호(1D 텐서)를 반환.
        """
        return self.data[idx]

# 1) Dataset 생성
dataset = MySignalDataset(data)

# 2) DataLoader로 배치 단위 생성
#    원하는 batch_size 로 설정하세요 (예: 16)
batch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 3) DataLoader 순회 예시
for i, batch_data in enumerate(dataloader):
    # batch_data.shape == (batch_size, signal_length)
    print(f"Batch {i} shape: {batch_data.shape}")
    # ... 학습 또는 처리 로직 ...
    # break  # 한 번만 보고 싶다면
Batch 0 shape: torch.Size([10, 1, 4000])
Batch 1 shape: torch.Size([10, 1, 4000])
Batch 2 shape: torch.Size([10, 1, 4000])
Batch 3 shape: torch.Size([10, 1, 4000])
Batch 4 shape: torch.Size([10, 1, 4000])
Batch 5 shape: torch.Size([10, 1, 4000])
Batch 6 shape: torch.Size([10, 1, 4000])
Batch 7 shape: torch.Size([10, 1, 4000])
Batch 8 shape: torch.Size([10, 1, 4000])
Batch 9 shape: torch.Size([10, 1, 4000])
Batch 10 shape: torch.Size([10, 1, 4000])
Batch 11 shape: torch.Size([10, 1, 4000])
Batch 12 shape: torch.Size([10, 1, 4000])
Batch 13 shape: torch.Size([10, 1, 4000])
Batch 14 shape: torch.Size([10, 1, 4000])
Batch 15 shape: torch.Size([10, 1, 4000])
Batch 16 shape: torch.Size([10, 1, 4000])
Batch 17 shape: torch.Size([10, 1, 4000])
Batch 18 shape: torch.Size([10, 1, 4000])
Batch 19 shape: torch.Size([10, 1, 4000])
Batch 20 shape: torch.Size([10, 1, 4000])
Batch 21 shape: torch.Size([10, 1, 4000])
Batch 22 shape: torch.Size([10, 1, 4000])
Batch 23 shape: torch.Size([10, 1, 4000])
Batch 24 shape: torch.Size([10, 1, 4000])
Batch 25 shape: torch.Size([10, 1, 4000])
Batch 26 shape: torch.Size([10, 1, 4000])
Batch 27 shape: torch.Size([10, 1, 4000])
Batch 28 shape: torch.Size([10, 1, 4000])
Batch 29 shape: torch.Size([10, 1, 4000])
Batch 30 shape: torch.Size([10, 1, 4000])
Batch 31 shape: torch.Size([10, 1, 4000])
Batch 32 shape: torch.Size([10, 1, 4000])
Batch 33 shape: torch.Size([10, 1, 4000])
Batch 34 shape: torch.Size([10, 1, 4000])
Batch 35 shape: torch.Size([10, 1, 4000])
Batch 36 shape: torch.Size([10, 1, 4000])
Batch 37 shape: torch.Size([10, 1, 4000])
Batch 38 shape: torch.Size([10, 1, 4000])
Batch 39 shape: torch.Size([10, 1, 4000])
Batch 40 shape: torch.Size([10, 1, 4000])
Batch 41 shape: torch.Size([10, 1, 4000])
Batch 42 shape: torch.Size([10, 1, 4000])
Batch 43 shape: torch.Size([10, 1, 4000])
Batch 44 shape: torch.Size([10, 1, 4000])
Batch 45 shape: torch.Size([10, 1, 4000])
Batch 46 shape: torch.Size([10, 1, 4000])
Batch 47 shape: torch.Size([10, 1, 4000])
Batch 48 shape: torch.Size([10, 1, 4000])
Batch 49 shape: torch.Size([10, 1, 4000])
Batch 50 shape: torch.Size([10, 1, 4000])
Batch 51 shape: torch.Size([10, 1, 4000])
Batch 52 shape: torch.Size([10, 1, 4000])
Batch 53 shape: torch.Size([10, 1, 4000])
Batch 54 shape: torch.Size([10, 1, 4000])
Batch 55 shape: torch.Size([10, 1, 4000])
Batch 56 shape: torch.Size([10, 1, 4000])
Batch 57 shape: torch.Size([10, 1, 4000])
Batch 58 shape: torch.Size([10, 1, 4000])
Batch 59 shape: torch.Size([10, 1, 4000])
Batch 60 shape: torch.Size([10, 1, 4000])
Batch 61 shape: torch.Size([10, 1, 4000])
Batch 62 shape: torch.Size([10, 1, 4000])
Batch 63 shape: torch.Size([10, 1, 4000])
Batch 64 shape: torch.Size([10, 1, 4000])
Batch 65 shape: torch.Size([10, 1, 4000])
Batch 66 shape: torch.Size([10, 1, 4000])
Batch 67 shape: torch.Size([10, 1, 4000])
Batch 68 shape: torch.Size([10, 1, 4000])
Batch 69 shape: torch.Size([10, 1, 4000])
Batch 70 shape: torch.Size([10, 1, 4000])
Batch 71 shape: torch.Size([10, 1, 4000])
Batch 72 shape: torch.Size([10, 1, 4000])
Batch 73 shape: torch.Size([10, 1, 4000])
Batch 74 shape: torch.Size([10, 1, 4000])
Batch 75 shape: torch.Size([10, 1, 4000])
Batch 76 shape: torch.Size([10, 1, 4000])
Batch 77 shape: torch.Size([10, 1, 4000])
Batch 78 shape: torch.Size([10, 1, 4000])
Batch 79 shape: torch.Size([10, 1, 4000])
Batch 80 shape: torch.Size([10, 1, 4000])
Batch 81 shape: torch.Size([10, 1, 4000])
Batch 82 shape: torch.Size([10, 1, 4000])
Batch 83 shape: torch.Size([10, 1, 4000])
Batch 84 shape: torch.Size([10, 1, 4000])
Batch 85 shape: torch.Size([10, 1, 4000])
Batch 86 shape: torch.Size([10, 1, 4000])
Batch 87 shape: torch.Size([10, 1, 4000])
Batch 88 shape: torch.Size([10, 1, 4000])
Batch 89 shape: torch.Size([10, 1, 4000])
Batch 90 shape: torch.Size([10, 1, 4000])
Batch 91 shape: torch.Size([10, 1, 4000])
Batch 92 shape: torch.Size([10, 1, 4000])
Batch 93 shape: torch.Size([10, 1, 4000])
Batch 94 shape: torch.Size([10, 1, 4000])
Batch 95 shape: torch.Size([10, 1, 4000])
Batch 96 shape: torch.Size([10, 1, 4000])
Batch 97 shape: torch.Size([10, 1, 4000])
Batch 98 shape: torch.Size([10, 1, 4000])
Batch 99 shape: torch.Size([10, 1, 4000])
Batch 100 shape: torch.Size([10, 1, 4000])
Batch 101 shape: torch.Size([10, 1, 4000])
Batch 102 shape: torch.Size([10, 1, 4000])
Batch 103 shape: torch.Size([10, 1, 4000])
Batch 104 shape: torch.Size([10, 1, 4000])
Batch 105 shape: torch.Size([10, 1, 4000])
Batch 106 shape: torch.Size([10, 1, 4000])
Batch 107 shape: torch.Size([10, 1, 4000])
Batch 108 shape: torch.Size([10, 1, 4000])
Batch 109 shape: torch.Size([10, 1, 4000])
Batch 110 shape: torch.Size([10, 1, 4000])
Batch 111 shape: torch.Size([10, 1, 4000])
Batch 112 shape: torch.Size([10, 1, 4000])
Batch 113 shape: torch.Size([10, 1, 4000])
Batch 114 shape: torch.Size([10, 1, 4000])
Batch 115 shape: torch.Size([10, 1, 4000])
Batch 116 shape: torch.Size([10, 1, 4000])
Batch 117 shape: torch.Size([10, 1, 4000])
Batch 118 shape: torch.Size([10, 1, 4000])
Batch 119 shape: torch.Size([10, 1, 4000])
Batch 120 shape: torch.Size([10, 1, 4000])
Batch 121 shape: torch.Size([10, 1, 4000])
Batch 122 shape: torch.Size([10, 1, 4000])
Batch 123 shape: torch.Size([10, 1, 4000])
Batch 124 shape: torch.Size([10, 1, 4000])
Batch 125 shape: torch.Size([10, 1, 4000])
Batch 126 shape: torch.Size([10, 1, 4000])
Batch 127 shape: torch.Size([10, 1, 4000])
Batch 128 shape: torch.Size([10, 1, 4000])
Batch 129 shape: torch.Size([10, 1, 4000])
Batch 130 shape: torch.Size([10, 1, 4000])
Batch 131 shape: torch.Size([10, 1, 4000])
Batch 132 shape: torch.Size([10, 1, 4000])
Batch 133 shape: torch.Size([10, 1, 4000])
Batch 134 shape: torch.Size([10, 1, 4000])
Batch 135 shape: torch.Size([10, 1, 4000])
Batch 136 shape: torch.Size([10, 1, 4000])
Batch 137 shape: torch.Size([10, 1, 4000])
Batch 138 shape: torch.Size([10, 1, 4000])
Batch 139 shape: torch.Size([10, 1, 4000])
Batch 140 shape: torch.Size([10, 1, 4000])
Batch 141 shape: torch.Size([10, 1, 4000])
Batch 142 shape: torch.Size([10, 1, 4000])
Batch 143 shape: torch.Size([10, 1, 4000])
Batch 144 shape: torch.Size([10, 1, 4000])
Batch 145 shape: torch.Size([10, 1, 4000])
Batch 146 shape: torch.Size([10, 1, 4000])
Batch 147 shape: torch.Size([10, 1, 4000])
Batch 148 shape: torch.Size([10, 1, 4000])
Batch 149 shape: torch.Size([10, 1, 4000])
Batch 150 shape: torch.Size([10, 1, 4000])
Batch 151 shape: torch.Size([10, 1, 4000])
Batch 152 shape: torch.Size([10, 1, 4000])
Batch 153 shape: torch.Size([10, 1, 4000])
Batch 154 shape: torch.Size([10, 1, 4000])
Batch 155 shape: torch.Size([10, 1, 4000])
Batch 156 shape: torch.Size([10, 1, 4000])
Batch 157 shape: torch.Size([10, 1, 4000])
Batch 158 shape: torch.Size([10, 1, 4000])
Batch 159 shape: torch.Size([10, 1, 4000])
Batch 160 shape: torch.Size([10, 1, 4000])
Batch 161 shape: torch.Size([10, 1, 4000])
Batch 162 shape: torch.Size([10, 1, 4000])
Batch 163 shape: torch.Size([10, 1, 4000])
Batch 164 shape: torch.Size([10, 1, 4000])
Batch 165 shape: torch.Size([10, 1, 4000])
Batch 166 shape: torch.Size([10, 1, 4000])
Batch 167 shape: torch.Size([10, 1, 4000])
Batch 168 shape: torch.Size([10, 1, 4000])
Batch 169 shape: torch.Size([10, 1, 4000])
Batch 170 shape: torch.Size([10, 1, 4000])
Batch 171 shape: torch.Size([10, 1, 4000])
Batch 172 shape: torch.Size([10, 1, 4000])
Batch 173 shape: torch.Size([10, 1, 4000])
Batch 174 shape: torch.Size([10, 1, 4000])
Batch 175 shape: torch.Size([10, 1, 4000])
Batch 176 shape: torch.Size([10, 1, 4000])
Batch 177 shape: torch.Size([10, 1, 4000])
Batch 178 shape: torch.Size([10, 1, 4000])
Batch 179 shape: torch.Size([10, 1, 4000])
Batch 180 shape: torch.Size([10, 1, 4000])
Batch 181 shape: torch.Size([10, 1, 4000])
Batch 182 shape: torch.Size([10, 1, 4000])
Batch 183 shape: torch.Size([10, 1, 4000])
Batch 184 shape: torch.Size([10, 1, 4000])
Batch 185 shape: torch.Size([10, 1, 4000])
Batch 186 shape: torch.Size([10, 1, 4000])
Batch 187 shape: torch.Size([10, 1, 4000])
Batch 188 shape: torch.Size([10, 1, 4000])
Batch 189 shape: torch.Size([10, 1, 4000])
Batch 190 shape: torch.Size([10, 1, 4000])
Batch 191 shape: torch.Size([10, 1, 4000])
Batch 192 shape: torch.Size([10, 1, 4000])
Batch 193 shape: torch.Size([10, 1, 4000])
Batch 194 shape: torch.Size([10, 1, 4000])
Batch 195 shape: torch.Size([10, 1, 4000])
Batch 196 shape: torch.Size([10, 1, 4000])
Batch 197 shape: torch.Size([10, 1, 4000])
Batch 198 shape: torch.Size([10, 1, 4000])
Batch 199 shape: torch.Size([10, 1, 4000])
Batch 200 shape: torch.Size([10, 1, 4000])
Batch 201 shape: torch.Size([10, 1, 4000])
Batch 202 shape: torch.Size([10, 1, 4000])
Batch 203 shape: torch.Size([10, 1, 4000])
Batch 204 shape: torch.Size([10, 1, 4000])
Batch 205 shape: torch.Size([10, 1, 4000])
Batch 206 shape: torch.Size([10, 1, 4000])
Batch 207 shape: torch.Size([10, 1, 4000])
Batch 208 shape: torch.Size([10, 1, 4000])
Batch 209 shape: torch.Size([10, 1, 4000])
Batch 210 shape: torch.Size([10, 1, 4000])
Batch 211 shape: torch.Size([10, 1, 4000])
Batch 212 shape: torch.Size([10, 1, 4000])
Batch 213 shape: torch.Size([10, 1, 4000])
Batch 214 shape: torch.Size([10, 1, 4000])
Batch 215 shape: torch.Size([10, 1, 4000])
Batch 216 shape: torch.Size([10, 1, 4000])
Batch 217 shape: torch.Size([10, 1, 4000])
Batch 218 shape: torch.Size([10, 1, 4000])
Batch 219 shape: torch.Size([10, 1, 4000])
Batch 220 shape: torch.Size([10, 1, 4000])
Batch 221 shape: torch.Size([10, 1, 4000])
Batch 222 shape: torch.Size([10, 1, 4000])
Batch 223 shape: torch.Size([10, 1, 4000])
Batch 224 shape: torch.Size([10, 1, 4000])
Batch 225 shape: torch.Size([10, 1, 4000])
Batch 226 shape: torch.Size([10, 1, 4000])
Batch 227 shape: torch.Size([10, 1, 4000])
Batch 228 shape: torch.Size([10, 1, 4000])
Batch 229 shape: torch.Size([10, 1, 4000])
Batch 230 shape: torch.Size([10, 1, 4000])
Batch 231 shape: torch.Size([10, 1, 4000])
Batch 232 shape: torch.Size([10, 1, 4000])
Batch 233 shape: torch.Size([10, 1, 4000])
Batch 234 shape: torch.Size([10, 1, 4000])
Batch 235 shape: torch.Size([10, 1, 4000])
Batch 236 shape: torch.Size([10, 1, 4000])
Batch 237 shape: torch.Size([10, 1, 4000])
Batch 238 shape: torch.Size([10, 1, 4000])
Batch 239 shape: torch.Size([10, 1, 4000])
Batch 240 shape: torch.Size([10, 1, 4000])
Batch 241 shape: torch.Size([10, 1, 4000])
Batch 242 shape: torch.Size([10, 1, 4000])
Batch 243 shape: torch.Size([10, 1, 4000])
Batch 244 shape: torch.Size([10, 1, 4000])
Batch 245 shape: torch.Size([10, 1, 4000])
Batch 246 shape: torch.Size([10, 1, 4000])
Batch 247 shape: torch.Size([10, 1, 4000])
Batch 248 shape: torch.Size([10, 1, 4000])
Batch 249 shape: torch.Size([10, 1, 4000])
Batch 250 shape: torch.Size([10, 1, 4000])
Batch 251 shape: torch.Size([10, 1, 4000])
Batch 252 shape: torch.Size([10, 1, 4000])
Batch 253 shape: torch.Size([10, 1, 4000])
Batch 254 shape: torch.Size([10, 1, 4000])
Batch 255 shape: torch.Size([10, 1, 4000])
Batch 256 shape: torch.Size([10, 1, 4000])
Batch 257 shape: torch.Size([10, 1, 4000])
Batch 258 shape: torch.Size([10, 1, 4000])
Batch 259 shape: torch.Size([10, 1, 4000])
Batch 260 shape: torch.Size([10, 1, 4000])
Batch 261 shape: torch.Size([10, 1, 4000])
Batch 262 shape: torch.Size([10, 1, 4000])
Batch 263 shape: torch.Size([10, 1, 4000])
Batch 264 shape: torch.Size([10, 1, 4000])
Batch 265 shape: torch.Size([10, 1, 4000])
Batch 266 shape: torch.Size([10, 1, 4000])
Batch 267 shape: torch.Size([10, 1, 4000])
Batch 268 shape: torch.Size([10, 1, 4000])
Batch 269 shape: torch.Size([10, 1, 4000])
Batch 270 shape: torch.Size([10, 1, 4000])
Batch 271 shape: torch.Size([10, 1, 4000])
Batch 272 shape: torch.Size([10, 1, 4000])
Batch 273 shape: torch.Size([10, 1, 4000])
Batch 274 shape: torch.Size([10, 1, 4000])
Batch 275 shape: torch.Size([10, 1, 4000])
Batch 276 shape: torch.Size([10, 1, 4000])
Batch 277 shape: torch.Size([10, 1, 4000])
Batch 278 shape: torch.Size([10, 1, 4000])
Batch 279 shape: torch.Size([10, 1, 4000])
Batch 280 shape: torch.Size([10, 1, 4000])
Batch 281 shape: torch.Size([10, 1, 4000])
Batch 282 shape: torch.Size([10, 1, 4000])
Batch 283 shape: torch.Size([10, 1, 4000])
Batch 284 shape: torch.Size([10, 1, 4000])
Batch 285 shape: torch.Size([10, 1, 4000])
Batch 286 shape: torch.Size([10, 1, 4000])
Batch 287 shape: torch.Size([10, 1, 4000])
Batch 288 shape: torch.Size([10, 1, 4000])
Batch 289 shape: torch.Size([10, 1, 4000])
Batch 290 shape: torch.Size([10, 1, 4000])
Batch 291 shape: torch.Size([10, 1, 4000])
Batch 292 shape: torch.Size([10, 1, 4000])
Batch 293 shape: torch.Size([10, 1, 4000])
Batch 294 shape: torch.Size([10, 1, 4000])
Batch 295 shape: torch.Size([10, 1, 4000])
Batch 296 shape: torch.Size([10, 1, 4000])
Batch 297 shape: torch.Size([10, 1, 4000])
Batch 298 shape: torch.Size([10, 1, 4000])
Batch 299 shape: torch.Size([10, 1, 4000])
Batch 300 shape: torch.Size([10, 1, 4000])
Batch 301 shape: torch.Size([10, 1, 4000])
Batch 302 shape: torch.Size([10, 1, 4000])
Batch 303 shape: torch.Size([10, 1, 4000])
Batch 304 shape: torch.Size([10, 1, 4000])
Batch 305 shape: torch.Size([10, 1, 4000])
Batch 306 shape: torch.Size([10, 1, 4000])
Batch 307 shape: torch.Size([10, 1, 4000])
Batch 308 shape: torch.Size([10, 1, 4000])
Batch 309 shape: torch.Size([10, 1, 4000])
Batch 310 shape: torch.Size([10, 1, 4000])
Batch 311 shape: torch.Size([10, 1, 4000])
Batch 312 shape: torch.Size([10, 1, 4000])
Batch 313 shape: torch.Size([10, 1, 4000])
Batch 314 shape: torch.Size([10, 1, 4000])
Batch 315 shape: torch.Size([10, 1, 4000])
Batch 316 shape: torch.Size([10, 1, 4000])
Batch 317 shape: torch.Size([10, 1, 4000])
Batch 318 shape: torch.Size([10, 1, 4000])
Batch 319 shape: torch.Size([10, 1, 4000])
Batch 320 shape: torch.Size([10, 1, 4000])
Batch 321 shape: torch.Size([10, 1, 4000])
Batch 322 shape: torch.Size([10, 1, 4000])
Batch 323 shape: torch.Size([10, 1, 4000])
Batch 324 shape: torch.Size([10, 1, 4000])
Batch 325 shape: torch.Size([10, 1, 4000])
Batch 326 shape: torch.Size([10, 1, 4000])
Batch 327 shape: torch.Size([10, 1, 4000])
Batch 328 shape: torch.Size([10, 1, 4000])
Batch 329 shape: torch.Size([10, 1, 4000])
Batch 330 shape: torch.Size([10, 1, 4000])
Batch 331 shape: torch.Size([10, 1, 4000])
Batch 332 shape: torch.Size([10, 1, 4000])
Batch 333 shape: torch.Size([10, 1, 4000])
Batch 334 shape: torch.Size([10, 1, 4000])
Batch 335 shape: torch.Size([10, 1, 4000])
Batch 336 shape: torch.Size([10, 1, 4000])
Batch 337 shape: torch.Size([10, 1, 4000])
Batch 338 shape: torch.Size([10, 1, 4000])
Batch 339 shape: torch.Size([10, 1, 4000])
Batch 340 shape: torch.Size([10, 1, 4000])
Batch 341 shape: torch.Size([10, 1, 4000])
Batch 342 shape: torch.Size([10, 1, 4000])
Batch 343 shape: torch.Size([10, 1, 4000])
Batch 344 shape: torch.Size([10, 1, 4000])
Batch 345 shape: torch.Size([10, 1, 4000])
Batch 346 shape: torch.Size([10, 1, 4000])
Batch 347 shape: torch.Size([10, 1, 4000])
Batch 348 shape: torch.Size([10, 1, 4000])
Batch 349 shape: torch.Size([10, 1, 4000])
Batch 350 shape: torch.Size([10, 1, 4000])
Batch 351 shape: torch.Size([10, 1, 4000])
Batch 352 shape: torch.Size([10, 1, 4000])
Batch 353 shape: torch.Size([10, 1, 4000])
Batch 354 shape: torch.Size([10, 1, 4000])
Batch 355 shape: torch.Size([10, 1, 4000])
Batch 356 shape: torch.Size([10, 1, 4000])
Batch 357 shape: torch.Size([10, 1, 4000])
Batch 358 shape: torch.Size([10, 1, 4000])
Batch 359 shape: torch.Size([10, 1, 4000])
Batch 360 shape: torch.Size([10, 1, 4000])
Batch 361 shape: torch.Size([10, 1, 4000])
Batch 362 shape: torch.Size([10, 1, 4000])
Batch 363 shape: torch.Size([10, 1, 4000])
Batch 364 shape: torch.Size([10, 1, 4000])
Batch 365 shape: torch.Size([10, 1, 4000])
Batch 366 shape: torch.Size([10, 1, 4000])
Batch 367 shape: torch.Size([10, 1, 4000])
Batch 368 shape: torch.Size([10, 1, 4000])
Batch 369 shape: torch.Size([10, 1, 4000])
Batch 370 shape: torch.Size([10, 1, 4000])
Batch 371 shape: torch.Size([10, 1, 4000])
Batch 372 shape: torch.Size([10, 1, 4000])
Batch 373 shape: torch.Size([10, 1, 4000])
Batch 374 shape: torch.Size([10, 1, 4000])
Batch 375 shape: torch.Size([10, 1, 4000])
Batch 376 shape: torch.Size([10, 1, 4000])
Batch 377 shape: torch.Size([10, 1, 4000])
Batch 378 shape: torch.Size([10, 1, 4000])
Batch 379 shape: torch.Size([10, 1, 4000])
Batch 380 shape: torch.Size([10, 1, 4000])
Batch 381 shape: torch.Size([10, 1, 4000])
Batch 382 shape: torch.Size([10, 1, 4000])
Batch 383 shape: torch.Size([10, 1, 4000])
Batch 384 shape: torch.Size([10, 1, 4000])
Batch 385 shape: torch.Size([10, 1, 4000])
Batch 386 shape: torch.Size([10, 1, 4000])
Batch 387 shape: torch.Size([10, 1, 4000])
Batch 388 shape: torch.Size([10, 1, 4000])
Batch 389 shape: torch.Size([10, 1, 4000])
Batch 390 shape: torch.Size([10, 1, 4000])
Batch 391 shape: torch.Size([10, 1, 4000])
Batch 392 shape: torch.Size([10, 1, 4000])
Batch 393 shape: torch.Size([10, 1, 4000])
Batch 394 shape: torch.Size([10, 1, 4000])
Batch 395 shape: torch.Size([10, 1, 4000])
Batch 396 shape: torch.Size([10, 1, 4000])
Batch 397 shape: torch.Size([10, 1, 4000])
Batch 398 shape: torch.Size([10, 1, 4000])
Batch 399 shape: torch.Size([10, 1, 4000])
In [86]:
def exists(x):
    return x is not None


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def Upsample(dim, dim_out):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv1d(dim, dim_out, 3, padding=1)
    )


def Downsample(dim, dim_out):
    return nn.Conv1d(dim, dim_out, 4, 2, 1)


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(1, dim, 1))

    def forward(self, x):
        return F.normalize(x, dim=1) * self.g * self.scale
    

class Block(nn.Module):
    def __init__(self, dim, dim_out, dropout=0.):
        super().__init__()
        self.proj = nn.Conv1d(dim, dim_out, 3, padding=1)
        self.norm = RMSNorm(dim_out)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift         # scaling & shifting using time emb
        
        x = self.act(x)
        return self.dropout(x)
    

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, dropout=0.):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, dropout=dropout)
        self.block2 = Block(dim_out, dim_out)
        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = time_emb.unsqueeze(2)
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv1d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        B, C, L = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: t.reshape(B, self.heads, -1, L), qkv)

        q = q.softmax(dim=-2)   # [B, h, d, L]
        k = k.softmax(dim=-1)   # [B, h, d, L]
        # v.shape = [B, h, e, L]

        q = q * self.scale

        context = torch.matmul(k, v.transpose(-1,-2))  # [B, h, d, e]
        
        out = torch.matmul(context.transpose(-1,-2), q) # [B, h, d, L]
        out = out.reshape(B, -1, L)
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x):
        B, C, L = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: t.reshape(B, self.heads, -1, L), qkv)
        # q.shape = [B, h, d, I], k.shape = [B, h, d, L], v.shape = [B, h, d, L]
        q = q * self.scale

        sim = torch.matmul(q.transpose(-1,-2), k)   # [B, h, I, L]
        attn = sim.softmax(dim=-1)
        out = torch.matmul(attn, v.transpose(-1,-2))    # [B, h, I, d]

        out = out.reshape(B, -1, L)
        return self.to_out(out)


class LinearAttentionBlock(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.fc = LinearAttention(dim, heads, dim_head)

    def forward(self, x):
        h = self.norm(x)
        h = self.fc(h)
        return h + x
    

class AttentionBlock(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.fc = Attention(dim, heads, dim_head)

    def forward(self, x):
        h = self.norm(x)
        h = self.fc(h)
        return h + x
    

class DownBlock(nn.Module):
    def __init__(self, dim, dim_in, dim_out, heads=4, dim_head=32, dropout=0., last=False):
        super().__init__()
        self.time_dim = dim * 4

        self.resnetblock1 = ResnetBlock(dim_in, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
        self.resnetblock2 = ResnetBlock(dim_in, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
        self.linattnblock = LinearAttentionBlock(dim_in, heads, dim_head)
        self.downsample = Downsample(dim_in, dim_out) if not last else nn.Conv1d(dim_in, dim_out, 3, padding=1)

    def forward(self, x, t):
        h1 = self.resnetblock1(x, t)
        h2 = self.resnetblock2(h1, t)
        h2 = self.linattnblock(h2)
        out = self.downsample(h2)
        return out, h1, h2


class MidBlock(nn.Module):
    def __init__(self, dim, mid_dim, heads=4, dim_head=32, dropout=0.):
        super().__init__()
        self.time_dim = dim * 4

        self.resnetblock1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=self.time_dim, dropout=dropout)
        self.attnblock = AttentionBlock(mid_dim, heads, dim_head)
        self.resnetblock2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=self.time_dim, dropout=dropout)

    def forward(self, x, t):
        x = self.resnetblock1(x, t)
        x = self.attnblock(x)
        return self.resnetblock2(x, t)
    

class UpBlock(nn.Module):
    def __init__(self, dim, dim_in, dim_out, heads=4, dim_head=32, dropout=0., last=False):
        super().__init__()
        self.time_dim = dim * 4

        self.resnetblock1 = ResnetBlock(dim_in + dim_out, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
        self.resnetblock2 = ResnetBlock(dim_in + dim_out, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
        self.linattnblock = LinearAttentionBlock(dim_in, heads, dim_head)
        self.upsample = Upsample(dim_in, dim_out) if not last else nn.Conv1d(dim_in, dim_out, 3, padding=1)

    def forward(self, x, h1, h2, t):
        x = self.resnetblock1(torch.cat((x, h1), dim=1), t)
        x = self.resnetblock2(torch.cat((x, h2), dim=1), t)
        x = self.linattnblock(x)
        return self.upsample(x)


class Unet(nn.Module):
    def __init__(self, dim=16):
        super(Unet, self).__init__()
        time_dim = dim * 4

        self.init_conv = nn.Conv1d(1, dim, 7, padding=3)

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        self.down1 = DownBlock(dim, dim_in = dim, dim_out = dim*2)
        self.down2 = DownBlock(dim, dim_in = dim*2, dim_out = dim*4)
        self.down3 = DownBlock(dim, dim_in = dim*4, dim_out = dim*8)
        self.down4 = DownBlock(dim, dim_in = dim*8, dim_out = dim*16, last=True)
        
        self.mid = MidBlock(dim, mid_dim = dim*16)

        self.up1 = UpBlock(dim, dim_in = dim*16, dim_out = dim*8)
        self.up2 = UpBlock(dim, dim_in = dim*8, dim_out = dim*4)
        self.up3 = UpBlock(dim, dim_in = dim*4, dim_out = dim*2)
        self.up4 = UpBlock(dim, dim_in = dim*2, dim_out = dim, last=True)

        self.final_res_block = ResnetBlock(dim*2, dim)
        self.final_conv = nn.Conv1d(dim, 1, 1)

    def forward(self, x, time):
        r = self.init_conv(x)

        t = self.time_mlp(time)

        x, h1, h2 = self.down1(r, t)
        x, h3, h4 = self.down2(x, t)
        x, h5, h6 = self.down3(x, t)
        x, h7, h8 = self.down4(x, t)

        x = self.mid(x, t)

        x = self.up1(x, h7, h8, t)
        x = self.up2(x, h5, h6, t)
        x = self.up3(x, h3, h4, t)
        x = self.up4(x, h1, h2, t)

        x = torch.cat((x, r), dim=1)
        x = self.final_res_block(x, t)
        return self.final_conv(x)
In [87]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
In [88]:
def train_ddpm(model, dataloader, optimizer, epochs=50):
    model.train()
    model.to(device)

    for epoch in range(epochs):
        total_loss = 0

        for step, x0 in enumerate(dataloader):
            # x0.shape = [batch_size, 1, signal_length]
            x0 = x0.to(device)

            # 1) t를 uniform random으로 뽑기
            #    t range: [1, T], 실제 구현에서는 [0, T-1]도 가능
            batch_size = x0.shape[0]
            t = torch.randint(
                low=1, high=T+1, size=(batch_size,), device=device
            )  # t in [1..T]

            # 2) 노이즈 epsilon 샘플
            epsilon = torch.randn_like(x0)

            # 3) alpha_bar_t 가져오기
            #    t-1 인덱스로 indexing (파이썬은 0-based, t는 1-based)
            alpha_bar_t = alpha_bar.to(device)[t-1].reshape(batch_size, 1, 1)
            
            # 4) x_t 생성
            #    x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * epsilon
            sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
            sqrt_one_minus_alpha_bar_t = torch.sqrt(1. - alpha_bar_t)
            x_t = sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * epsilon

            # 5) 모델 추론: 모델은 x_t와 t를 입력받아 epsilon 예측
            epsilon_pred = model(x_t, t)  # shape: same as x0

            # 6) MSE Loss
            loss = F.mse_loss(epsilon_pred, epsilon)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {total_loss / len(dataloader):.4f}")
In [89]:
model = Unet()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# 훈련
train_ddpm(model, dataloader, optimizer, epochs=200)
Epoch [1/200] - Loss: 0.1021
Epoch [2/200] - Loss: 0.0219
Epoch [3/200] - Loss: 0.0155
Epoch [4/200] - Loss: 0.0118
Epoch [5/200] - Loss: 0.0106
Epoch [6/200] - Loss: 0.0095
Epoch [7/200] - Loss: 0.0093
Epoch [8/200] - Loss: 0.0090
Epoch [9/200] - Loss: 0.0081
Epoch [10/200] - Loss: 0.0094
Epoch [11/200] - Loss: 0.0080
Epoch [12/200] - Loss: 0.0077
Epoch [13/200] - Loss: 0.0075
Epoch [14/200] - Loss: 0.0071
Epoch [15/200] - Loss: 0.0075
Epoch [16/200] - Loss: 0.0071
Epoch [17/200] - Loss: 0.0093
Epoch [18/200] - Loss: 0.0067
Epoch [19/200] - Loss: 0.0065
Epoch [20/200] - Loss: 0.0067
Epoch [21/200] - Loss: 0.0065
Epoch [22/200] - Loss: 0.0065
Epoch [23/200] - Loss: 0.0065
Epoch [24/200] - Loss: 0.0072
Epoch [25/200] - Loss: 0.0062
Epoch [26/200] - Loss: 0.0062
Epoch [27/200] - Loss: 0.0063
Epoch [28/200] - Loss: 0.0061
Epoch [29/200] - Loss: 0.0061
Epoch [30/200] - Loss: 0.0059
Epoch [31/200] - Loss: 0.0058
Epoch [32/200] - Loss: 0.0063
Epoch [33/200] - Loss: 0.0064
Epoch [34/200] - Loss: 0.0057
Epoch [35/200] - Loss: 0.0057
Epoch [36/200] - Loss: 0.0055
Epoch [37/200] - Loss: 0.0056
Epoch [38/200] - Loss: 0.0058
Epoch [39/200] - Loss: 0.0062
Epoch [40/200] - Loss: 0.0054
Epoch [41/200] - Loss: 0.0054
Epoch [42/200] - Loss: 0.0057
Epoch [43/200] - Loss: 0.0054
Epoch [44/200] - Loss: 0.0052
Epoch [45/200] - Loss: 0.0054
Epoch [46/200] - Loss: 0.0052
Epoch [47/200] - Loss: 0.0051
Epoch [48/200] - Loss: 0.0050
Epoch [49/200] - Loss: 0.0050
Epoch [50/200] - Loss: 0.0051
Epoch [51/200] - Loss: 0.0051
Epoch [52/200] - Loss: 0.0050
Epoch [53/200] - Loss: 0.0052
Epoch [54/200] - Loss: 0.0049
Epoch [55/200] - Loss: 0.0047
Epoch [56/200] - Loss: 0.0048
Epoch [57/200] - Loss: 0.0048
Epoch [58/200] - Loss: 0.0047
Epoch [59/200] - Loss: 0.0049
Epoch [60/200] - Loss: 0.0045
Epoch [61/200] - Loss: 0.0047
Epoch [62/200] - Loss: 0.0045
Epoch [63/200] - Loss: 0.0047
Epoch [64/200] - Loss: 0.0045
Epoch [65/200] - Loss: 0.0042
Epoch [66/200] - Loss: 0.0044
Epoch [67/200] - Loss: 0.0044
Epoch [68/200] - Loss: 0.0042
Epoch [69/200] - Loss: 0.0044
Epoch [70/200] - Loss: 0.0043
Epoch [71/200] - Loss: 0.0041
Epoch [72/200] - Loss: 0.0041
Epoch [73/200] - Loss: 0.0041
Epoch [74/200] - Loss: 0.0040
Epoch [75/200] - Loss: 0.0039
Epoch [76/200] - Loss: 0.0040
Epoch [77/200] - Loss: 0.0040
Epoch [78/200] - Loss: 0.0043
Epoch [79/200] - Loss: 0.0038
Epoch [80/200] - Loss: 0.0038
Epoch [81/200] - Loss: 0.0037
Epoch [82/200] - Loss: 0.0037
Epoch [83/200] - Loss: 0.0037
Epoch [84/200] - Loss: 0.0036
Epoch [85/200] - Loss: 0.0036
Epoch [86/200] - Loss: 0.0035
Epoch [87/200] - Loss: 0.0034
Epoch [88/200] - Loss: 0.0033
Epoch [89/200] - Loss: 0.0034
Epoch [90/200] - Loss: 0.0032
Epoch [91/200] - Loss: 0.0032
Epoch [92/200] - Loss: 0.0032
Epoch [93/200] - Loss: 0.0031
Epoch [94/200] - Loss: 0.0030
Epoch [95/200] - Loss: 0.0030
Epoch [96/200] - Loss: 0.0029
Epoch [97/200] - Loss: 0.0030
Epoch [98/200] - Loss: 0.0029
Epoch [99/200] - Loss: 0.0029
Epoch [100/200] - Loss: 0.0029
Epoch [101/200] - Loss: 0.0060
Epoch [102/200] - Loss: 0.0028
Epoch [103/200] - Loss: 0.0028
Epoch [104/200] - Loss: 0.0027
Epoch [105/200] - Loss: 0.0028
Epoch [106/200] - Loss: 0.0027
Epoch [107/200] - Loss: 0.0028
Epoch [108/200] - Loss: 0.0026
Epoch [109/200] - Loss: 0.0027
Epoch [110/200] - Loss: 0.0026
Epoch [111/200] - Loss: 0.0025
Epoch [112/200] - Loss: 0.0026
Epoch [113/200] - Loss: 0.0026
Epoch [114/200] - Loss: 0.0025
Epoch [115/200] - Loss: 0.0025
Epoch [116/200] - Loss: 0.0025
Epoch [117/200] - Loss: 0.0026
Epoch [118/200] - Loss: 0.0025
Epoch [119/200] - Loss: 0.0025
Epoch [120/200] - Loss: 0.0025
Epoch [121/200] - Loss: 0.0025
Epoch [122/200] - Loss: 0.0025
Epoch [123/200] - Loss: 0.0024
Epoch [124/200] - Loss: 0.0025
Epoch [125/200] - Loss: 0.0024
Epoch [126/200] - Loss: 0.0024
Epoch [127/200] - Loss: 0.0023
Epoch [128/200] - Loss: 0.0023
Epoch [129/200] - Loss: 0.0024
Epoch [130/200] - Loss: 0.0023
Epoch [131/200] - Loss: 0.0024
Epoch [132/200] - Loss: 0.0022
Epoch [133/200] - Loss: 0.0022
Epoch [134/200] - Loss: 0.0023
Epoch [135/200] - Loss: 0.0023
Epoch [136/200] - Loss: 0.0022
Epoch [137/200] - Loss: 0.0023
Epoch [138/200] - Loss: 0.0022
Epoch [139/200] - Loss: 0.0022
Epoch [140/200] - Loss: 0.0022
Epoch [141/200] - Loss: 0.0022
Epoch [142/200] - Loss: 0.0021
Epoch [143/200] - Loss: 0.0021
Epoch [144/200] - Loss: 0.0021
Epoch [145/200] - Loss: 0.0021
Epoch [146/200] - Loss: 0.0022
Epoch [147/200] - Loss: 0.0021
Epoch [148/200] - Loss: 0.0020
Epoch [149/200] - Loss: 0.0021
Epoch [150/200] - Loss: 0.0021
Epoch [151/200] - Loss: 0.0021
Epoch [152/200] - Loss: 0.0021
Epoch [153/200] - Loss: 0.0021
Epoch [154/200] - Loss: 0.0023
Epoch [155/200] - Loss: 0.0023
Epoch [156/200] - Loss: 0.0020
Epoch [157/200] - Loss: 0.0020
Epoch [158/200] - Loss: 0.0020
Epoch [159/200] - Loss: 0.0020
Epoch [160/200] - Loss: 0.0019
Epoch [161/200] - Loss: 0.0021
Epoch [162/200] - Loss: 0.0019
Epoch [163/200] - Loss: 0.0019
Epoch [164/200] - Loss: 0.0019
Epoch [165/200] - Loss: 0.0020
Epoch [166/200] - Loss: 0.0023
Epoch [167/200] - Loss: 0.0019
Epoch [168/200] - Loss: 0.0019
Epoch [169/200] - Loss: 0.0021
Epoch [170/200] - Loss: 0.0018
Epoch [171/200] - Loss: 0.0019
Epoch [172/200] - Loss: 0.0018
Epoch [173/200] - Loss: 0.0019
Epoch [174/200] - Loss: 0.0019
Epoch [175/200] - Loss: 0.0019
Epoch [176/200] - Loss: 0.0018
Epoch [177/200] - Loss: 0.0018
Epoch [178/200] - Loss: 0.0018
Epoch [179/200] - Loss: 0.0019
Epoch [180/200] - Loss: 0.0019
Epoch [181/200] - Loss: 0.0018
Epoch [182/200] - Loss: 0.0018
Epoch [183/200] - Loss: 0.0018
Epoch [184/200] - Loss: 0.0018
Epoch [185/200] - Loss: 0.0019
Epoch [186/200] - Loss: 0.0019
Epoch [187/200] - Loss: 0.0024
Epoch [188/200] - Loss: 0.0018
Epoch [189/200] - Loss: 0.0017
Epoch [190/200] - Loss: 0.0018
Epoch [191/200] - Loss: 0.0017
Epoch [192/200] - Loss: 0.0018
Epoch [193/200] - Loss: 0.0018
Epoch [194/200] - Loss: 0.0018
Epoch [195/200] - Loss: 0.0017
Epoch [196/200] - Loss: 0.0018
Epoch [197/200] - Loss: 0.0018
Epoch [198/200] - Loss: 0.0018
Epoch [199/200] - Loss: 0.0017
Epoch [200/200] - Loss: 0.0018

DDPM Sampling¶

In [106]:
@torch.no_grad()
def sample_ddpm(model, num_samples=1, signal_length=10*fs):
    """
    DDPM reverse diffusion sampling
    """
    model.eval()
    model.to(device)

    # 1) x_T ~ N(0, I)
    x_t = torch.randn(num_samples, 1, signal_length).to(device)

    for i in reversed(range(T)): # i: T-1 down to 0
        # i는 파이썬 인덱스, 실제 t는 i+1
        t_val = torch.tensor([i+1]*num_samples, device=device)  # shape = [num_samples]
        
        if (i+1)%100 == 0 :
            plt.figure(figsize=(120,5))

            for j in range(num_samples):
                signal = x_t[j,0].to('cpu').numpy()

                plt.subplot(1, 10, j+1)
                plt.title(f'X_{(i+1)} Signal')
                plt.plot(signal)
                plt.ylim(-4, 4)
                
            plt.show()

        # 모델의 예측 노이즈
        eps = model(x_t, t_val)

        sigma_t = torch.sqrt(beta[i])
        z = torch.randn(num_samples, 1, signal_length).to(device)

        if i == 0: sigma_t = 0

        alpha_t = alpha[i]
        alpha_bar_t = alpha_bar[i]
        # (주의) alpha_bar[i]는 t=i+1에 해당

        # 2) 역방향 공식
        # x_{t-1} = 1/sqrt(alpha_t) ( x_t - (1-alpha_t)/sqrt(1-alpha_bar_t)* eps )
        #           + sigma_t * z (if we add noise)
        # 여기서는 단순하게 sigma=0 가정
        one_over_sqrt_alpha_t = 1.0 / torch.sqrt(alpha_t)
        one_minus_alpha_t = 1.0 - alpha_t
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)

        # reshape to match (batch_size, 1, 1)
        one_over_sqrt_alpha_t = one_over_sqrt_alpha_t.view(1,1,1).to(device)
        one_minus_alpha_t = one_minus_alpha_t.view(1,1,1).to(device)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.view(1,1,1).to(device)

        x_prev = one_over_sqrt_alpha_t * (x_t - (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * eps) + sigma_t * z

        x_t = x_prev  # update

    # 최종 x_0 반환
    return x_t
In [107]:
# 샘플링 예시
num_samples = 4
samples = sample_ddpm(model, num_samples=num_samples, signal_length=1*fs)
t = torch.linspace(0, 1, 1*fs)

print("samples shape:", samples.shape)


for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    w = torch.linspace(-fs/2, fs/2, 1*fs)
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,5))

    plt.subplot(1, 2, 1)
    plt.plot(t, sample)
    plt.ylim(-4, 4)
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([4, 1, 4000])
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

$$ X_t = \sqrt{\bar{\alpha}_t}X_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon \;\;\; where \;\; \varepsilon \sim N(0,1) $$

In [108]:
@torch.no_grad()
def partial_diffusion(model, X0, Lambda=200, num_samples=1, signal_length=10*fs):
    """
    DDPM reverse partial diffusion sampling
    Lambda: partial diffusion time
    """
    model.eval()
    model.to(device)

    # 1) x_t = sqrt(alpha_bar_t)*X_0 + sqrt(1 - alpha_bar_t)*eps
    sqrt_alpha_bar_Lambda = torch.sqrt(alpha_bar[Lambda-1]).view(1,1,1).to(device)
    sqrt_one_minus_alpha_bar_Lambda = torch.sqrt(1.0 - alpha_bar[Lambda-1]).view(1,1,1).to(device)
    eps = torch.randn(num_samples, 1, signal_length).to(device)
    x_t = sqrt_alpha_bar_Lambda @ X0 + sqrt_one_minus_alpha_bar_Lambda @ eps

    for i in reversed(range(Lambda)): # i: Lambda-1 down to 0
        # i는 파이썬 인덱스, 실제 t는 i+1
        t_val = torch.tensor([i+1]*num_samples, device=device)  # shape = [num_samples]
        
        if (i+1) % (Lambda//10) == 0 :
            plt.figure(figsize=(120,5))

            for j in range(num_samples):
                signal = x_t[j,0].to('cpu').numpy()

                plt.subplot(1, 10, j+1)
                plt.title(f'X_{(i+1)} Signal')
                plt.plot(signal)
                plt.ylim(-4, 4)
                
            plt.show()

        # 모델의 예측 노이즈
        eps = model(x_t, t_val)

        sigma_t = torch.sqrt(beta[i])
        z = torch.randn(num_samples, 1, signal_length).to(device)

        if i == 0: sigma_t = 0

        alpha_t = alpha[i]
        alpha_bar_t = alpha_bar[i]
        # (주의) alpha_bar[i]는 t=i+1에 해당

        # 2) 역방향 공식
        # x_{t-1} = 1/sqrt(alpha_t) ( x_t - (1-alpha_t)/sqrt(1-alpha_bar_t)* eps )
        #           + sigma_t * z (if we add noise)
        one_over_sqrt_alpha_t = 1.0 / torch.sqrt(alpha_t)
        one_minus_alpha_t = 1.0 - alpha_t
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)

        # reshape to match (batch_size, 1, 1)
        one_over_sqrt_alpha_t = one_over_sqrt_alpha_t.view(1,1,1).to(device)
        one_minus_alpha_t = one_minus_alpha_t.view(1,1,1).to(device)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.view(1,1,1).to(device)

        x_prev = one_over_sqrt_alpha_t * (x_t - (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * eps) + sigma_t * z

        x_t = x_prev  # update

    # 최종 x_0 반환
    return x_t
In [109]:
# partial diffusion 예시
num_samples = 1
Lambda = 200
fs = 4000
t = torch.linspace(0, 1, 1*fs)
original = torch.sin(2 * torch.pi * t).reshape(1, 1, -1).to(device)

samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape)  # [num_samples, 1, 5*fs]

org_signal = original.reshape(-1).to('cpu')


w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))

plt.figure(figsize=(25,5))

plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()

plt.show()

for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,10))

    plt.subplot(2, 2, 1)
    plt.plot(t, sample)
    plt.grid()

    plt.subplot(2, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.subplot(2, 2, 3)
    plt.plot(t, sample - org_signal)
    plt.grid()

    res_fft = torch.fft.fft(sample - org_signal)
    res_fft_power = torch.abs(torch.fft.fftshift(res_fft))

    plt.subplot(2, 2, 4)
    plt.plot(w, res_fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([1, 1, 4000])
No description has been provided for this image
No description has been provided for this image
In [110]:
# partial diffusion 예시
num_samples = 1
Lambda = 200

t = torch.linspace(0, 1, 1*fs)

s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
original = (s1 + s2).to(device)

samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape)  # [num_samples, 1, 5*fs]

org_signal = original.reshape(-1).to('cpu')


w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))

plt.figure(figsize=(25,5))

plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()

plt.show()

for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,10))

    plt.subplot(2, 2, 1)
    plt.plot(t, sample)
    plt.grid()

    plt.subplot(2, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.subplot(2, 2, 3)
    plt.plot(t, sample - org_signal)
    plt.grid()

    res_fft = torch.fft.fft(sample - org_signal)
    res_fft_power = torch.abs(torch.fft.fftshift(res_fft))

    plt.subplot(2, 2, 4)
    plt.plot(w, res_fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([1, 1, 4000])
No description has been provided for this image
No description has been provided for this image
In [124]:
from scipy.stats import norm
In [133]:
# partial diffusion 예시
num_samples = 1
Lambda = 200

t = torch.linspace(0, 1, 1*fs)

s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
noise = 1 * torch.randn(t.shape)

original = (s1 + s2 + noise).to(device)

samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape)  # [num_samples, 1, 10*fs]

org_signal = original.reshape(-1).to('cpu')


w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))

plt.figure(figsize=(25,5))

plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()

plt.show()

for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,10))

    plt.subplot(2, 2, 1)
    plt.plot(t, sample)
    plt.grid()

    plt.subplot(2, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    residual = sample - org_signal

    plt.subplot(2, 2, 3)
    plt.plot(t, residual)
    plt.grid()

    res_fft = torch.fft.fft(residual)
    res_fft_power = torch.abs(torch.fft.fftshift(res_fft))

    plt.subplot(2, 2, 4)
    plt.plot(w, res_fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()

plt.figure(figsize=(12, 8))

gaus_x = torch.arange(-4, 4, 0.001)

plt.title(f'mean: {torch.mean(residual)}, std: {torch.std(residual)}')
plt.hist(residual, bins=200)
plt.plot(gaus_x, 130*norm.pdf(gaus_x, 0, 1))
plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([1, 1, 4000])
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [134]:
# partial diffusion 예시
num_samples = 1
Lambda = 400

t = torch.linspace(0, 1, 1*fs)

s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
noise = 1 * torch.randn(t.shape)

original = (s1 + s2 + noise).to(device)

samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape)  # [num_samples, 1, 10*fs]

org_signal = original.reshape(-1).to('cpu')


w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))

plt.figure(figsize=(25,5))

plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()

plt.show()

for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,10))

    plt.subplot(2, 2, 1)
    plt.plot(t, sample)
    plt.grid()

    plt.subplot(2, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    residual = sample - org_signal

    plt.subplot(2, 2, 3)
    plt.plot(t, residual)
    plt.grid()

    res_fft = torch.fft.fft(residual)
    res_fft_power = torch.abs(torch.fft.fftshift(res_fft))

    plt.subplot(2, 2, 4)
    plt.plot(w, res_fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()

plt.figure(figsize=(12, 8))

gaus_x = torch.arange(-4, 4, 0.001)

plt.title(f'mean: {torch.mean(residual)}, std: {torch.std(residual)}')
plt.hist(residual, bins=200)
plt.plot(gaus_x, 130*norm.pdf(gaus_x, 0, 1))
plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([1, 1, 4000])
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [135]:
# partial diffusion 예시
num_samples = 1
Lambda = 600

t = torch.linspace(0, 1, 1*fs)

s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
noise = 1 * torch.randn(t.shape)

original = (s1 + s2 + noise).to(device)

samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape)  # [num_samples, 1, 10*fs]

org_signal = original.reshape(-1).to('cpu')


w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))

plt.figure(figsize=(25,5))

plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()

plt.show()

for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,10))

    plt.subplot(2, 2, 1)
    plt.plot(t, sample)
    plt.grid()

    plt.subplot(2, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    residual = sample - org_signal

    plt.subplot(2, 2, 3)
    plt.plot(t, residual)
    plt.grid()

    res_fft = torch.fft.fft(residual)
    res_fft_power = torch.abs(torch.fft.fftshift(res_fft))

    plt.subplot(2, 2, 4)
    plt.plot(w, res_fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()

plt.figure(figsize=(12, 8))

gaus_x = torch.arange(-4, 4, 0.001)

plt.title(f'mean: {torch.mean(residual)}, std: {torch.std(residual)}')
plt.hist(residual, bins=200)
plt.plot(gaus_x, 130*norm.pdf(gaus_x, 0, 1))
plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([1, 1, 4000])
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [117]:
# 샘플링 예시
num_samples = 8
samples = sample_ddpm(model, num_samples=num_samples, signal_length=2*fs)
t = torch.linspace(0, 2, 2*fs)

print("samples shape:", samples.shape)


for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    w = torch.linspace(-fs/2, fs/2, 2*fs)
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,5))

    plt.subplot(1, 2, 1)
    plt.plot(t, sample)
    plt.ylim(-4, 4)
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 120)
    plt.grid()

    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([8, 1, 8000])
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

DDIM Sampling¶

In [114]:
@torch.no_grad()
def sample_ddim(model, num_samples=1, signal_length=10*fs):
    """
    DDPM reverse diffusion sampling
    """
    model.eval()
    model.to(device)

    # 1) x_T ~ N(0, I)
    x_t = torch.randn(num_samples, 1, signal_length).to(device)

    for i in reversed(range(T)): # i: T-1 down to 0
        # i는 파이썬 인덱스, 실제 t는 i+1
        t_val = torch.tensor([i+1]*num_samples, device=device)  # shape = [num_samples]
        
        if (i+1)%100 == 0 :
            plt.figure(figsize=(120,5))

            for j in range(num_samples):
                signal = x_t[j,0].to('cpu').numpy()

                plt.subplot(1, 10, j+1)
                plt.title(f'X_{(i+1)} Signal')
                plt.plot(signal)
                plt.ylim(-4, 4)
                
            plt.show()

        # 모델의 예측 노이즈
        eps = model(x_t, t_val)

        z = torch.randn(num_samples, 1, signal_length).to(device)

        # (주의) alpha_bar[i]는 t=i+1에 해당
        alpha_t = alpha[i]
        alpha_bar_t = alpha_bar[i]

        sigma_t = torch.sqrt((1.0 - alpha_t) * (1.0 - alpha_bar_t/alpha_t) / (1.0 - alpha_bar_t))

        # 역방향 공식 (DDIM)
        # x_{t-1} = 1/sqrt(alpha_t) (x_t - sqrt(1-alpha_bar_t)*eps) + sqrt(1 - alpha_bar_t/alpha_t - sigma_t^2)*eps + sigma_t * z
        one_over_sqrt_alpha_t = 1.0 / torch.sqrt(alpha_t)
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
        sqrt_one_minus_alpha_t_minus_1_minus_sigma_t_square = torch.sqrt(1.0 - alpha_bar_t/alpha_t - sigma_t**2)

        if i != 0:
            x_prev = one_over_sqrt_alpha_t * (x_t - sqrt_one_minus_alpha_bar_t*eps) + sqrt_one_minus_alpha_t_minus_1_minus_sigma_t_square*eps + sigma_t*z

        # x_1 -> x_0
        if i == 0:
            # 역방향 공식 (DDPM)
            # x_{t-1} = 1/sqrt(alpha_t) (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t)*eps)
            one_minus_alpha_t = 1.0 - alpha_t

            # reshape to match (batch_size, 1, 1)
            one_over_sqrt_alpha_t = one_over_sqrt_alpha_t.to(device)
            one_minus_alpha_t = one_minus_alpha_t.to(device)
            sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.to(device)

            x_prev = one_over_sqrt_alpha_t * (x_t - (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * eps)
            
        x_t = x_prev  # update

    # 최종 x_0 반환
    return x_t
In [115]:
# 샘플링 예시
num_samples = 5
samples = sample_ddim(model, num_samples=num_samples, signal_length=10*fs)
t = torch.linspace(0, 10, 10*fs)

print("samples shape:", samples.shape)


for i in range(num_samples):
    sample = samples[i,0].to('cpu')
    
    w = torch.linspace(-fs/2, fs/2, 10*fs)
    fft = torch.fft.fft(sample)
    fft_power = torch.abs(torch.fft.fftshift(fft))

    plt.figure(figsize=(25,5))

    plt.subplot(1, 2, 1)
    plt.plot(t, sample)
    plt.ylim(-4, 4)
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(w, fft_power)
    plt.xlim(0, 20)
    plt.grid()

    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
samples shape: torch.Size([5, 1, 40000])
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]: